#!/bin/bash

# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Llama2-7b.
# Additionally, this file serves as an integration test for context parallelism when training on TPUs in MaxText.

# The flow of this file is as follows:
# 1. Download the checkpoint from Meta (https://llama.meta.com/llama-downloads/) into your local directory. Convert this PyTorch checkpoint into the Orbax checkpoint format used by MaxText.
# 2. Run decoding and finetuning of Llama2-7b with this converted checkpoint. Also run pretraining of Llama2-7b.
# 3. Run decoding from the finetuned weights.
# 4. Convert the scanned checkpoint from step #1 into unscanned checkpoint format and run more efficient decoding.


set -ex
idx=$(date +%Y-%m-%d-%H-%M)

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
export DATASET_PATH=gs://maxtext-dataset
export ASYNC_CHECKPOINTING=false
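# For example (illustrative, commented out; the bucket names below are placeholders), non-Googlers could point these at buckets they own:
# export BASE_OUTPUT_DIRECTORY=gs://<your-output-bucket>/maxtext-runs
# export DATASET_PATH=gs://<your-dataset-bucket>/maxtext-dataset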

# We install the CPU-only build of torch because the checkpoint conversion script MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt does not need a TPU/GPU
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# We define a variable for the path to the Meta checkpoint. Non-Googlers please remember to point `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint
export META_CHECKPOINT_PATH=gs://maxtext-llama/llama2-7b/meta-ckpt
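# For example (illustrative, commented out; the path below is a placeholder), a non-Googler would set:
# export META_CHECKPOINT_PATH=gs://<your-bucket>/llama2-7b/meta-ckpt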

# In the following command, we copy Meta's checkpoint into the local directory `/tmp/`.
# You can use a different local directory than /tmp/; if you do so, please pass the same local path to `base-model-path` when running `python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt`
gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/
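# As a quick sanity check (illustrative, commented out), you can list the copied directory; for Llama2-7b,
# Meta's checkpoint typically contains a consolidated.*.pth weight file and params.json:
# ls /tmp/meta-ckpt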

# `CONVERTED_CHECKPOINT_PATH` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `CONVERTED_CHECKPOINT_PATH` to a GCS bucket that you own
export CONVERTED_CHECKPOINT_PATH=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext

# Next, run the conversion script `MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt` to convert Meta's PyTorch checkpoint from `base-model-path` and save the new converted (Orbax) checkpoint to `maxtext-model-path`
python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt --base-model-path /tmp/meta-ckpt --model-size llama2-7b --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}

# We define `CONVERTED_CHECKPOINT` to refer to the exact checkpoint subdirectory inside `CONVERTED_CHECKPOINT_PATH`. This makes the path easier to use in the `train` and `decode` commands
export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items
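# The `0/items` suffix corresponds to the checkpoint saved at step 0 in the Orbax directory layout. As an illustrative
# (commented-out) check, you can confirm the converted checkpoint exists before using it:
# gcloud storage ls ${CONVERTED_CHECKPOINT}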

# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint_${idx}
python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true

# Like before, we define `UNSCANNED_CKPT_PATH` to refer to the exact checkpoint subdirectory
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items

# We run decoding on `UNSCANNED_CKPT_PATH`, the unscanned version of the checkpoint converted directly from Meta's PyTorch checkpoint (aka `CONVERTED_CHECKPOINT`), for efficient decoding. Note that this checkpoint only has parameters and no optimizer state. So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`
python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=runner_decode_unscanned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4  max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false
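
# As an illustrative (commented-out) variant, you could decode a longer completion by raising the prefill and
# target lengths and changing the prompt; the values below are placeholders, not tuned settings:
# python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=runner_decode_unscanned_long_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=16 max_target_length=64 prompt="I love to travel because" attention=dot_product scan_layers=false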


# We can also run decoding (albeit in a less optimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`
python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4  max_target_length=16 prompt="I love to" attention=dot_product

# Alternatively, we can skip straight to finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that the scanned checkpoint format makes finetuning efficient
# We run this with context parallelism and the `ici_context_parallelism` flag as an integration test for context parallelism
python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_finetuning_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_context_parallelism=4 steps=10 checkpoint_period=5 packing=false
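# Note: `ici_context_parallelism=4` matches the 4 chips of the v4-8 this test runs on. On a different slice size you
# would typically set it to (a divisor of) the number of chips; as an illustrative, commented-out way to check the
# device count visible to JAX (assumes jax is installed, which MaxText requires):
# python3 -c "import jax; print(jax.device_count())"
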
# We also run pre-training of Llama2-7b; this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from
# We run this with context parallelism and the `ici_context_parallelism` flag as an integration test for context parallelism
python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml run_name=runner_pretraining_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_context_parallelism=4 steps=10 packing=false
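
# As an illustrative (commented-out) check, you can inspect what the pretraining run wrote under its run directory
# (checkpoints, logs, etc.); the exact contents may vary with the MaxText version:
# gcloud storage ls ${BASE_OUTPUT_DIRECTORY}/runner_pretraining_${idx}/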

# Now, run decoding on the checkpoint generated from our finetune run. Note that the finetune run saves the `full state`, which has both parameters and optimizer state. For decoding, we only need the parameters. So, we can use `MaxText.generate_param_only_checkpoint` to convert
# the full-state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside `BASE_OUTPUT_DIRECTORY` from our previous finetuning run, in this case the checkpoint saved at finetuning step #5
# Also, `force_unroll=true` converts the output parameter-only checkpoint into an unscanned format for efficient decoding
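# Because the finetuning run above used steps=10 and checkpoint_period=5, a checkpoint should have been saved at step 5.
# As an illustrative (commented-out) check, you can list the finetuning run's checkpoint directory to confirm:
# gcloud storage ls ${BASE_OUTPUT_DIRECTORY}/runner_finetuning_${idx}/checkpoints/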
export PARAMETER_CHECKPOINT_RUN=generate_param_only_checkpoint_${idx}
python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/runner_finetuning_${idx}/checkpoints/5/items run_name=${PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true

# Like before, we define `NEW_CKPT_PATH` to refer to the exact checkpoint subdirectory
export NEW_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items

# We run decoding on the fine-tuned parameter checkpoint
python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${NEW_CKPT_PATH} run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4  max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false

# We also test whether the forward pass logits match the golden logits for Llama2-7b
python3 -m tests.forward_pass_logit_checker  "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic  dtype=float32 scan_layers=false --rtol=0.1 --atol=0.1

# Convert the MaxText Orbax checkpoint to Hugging Face (HF) format
JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CONVERTED_CHECKPOINT} run_name=convert_to_hf model_name=llama2-7b hf_model_path=/tmp/hf_llama2
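
# As an illustrative (commented-out) check, the exported Hugging Face checkpoint can be inspected locally; it should
# contain the usual HF files (e.g. config.json and weight files), though exact filenames may vary:
# ls /tmp/hf_llama2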

# Test whether the forward pass logits match the golden logits for the Hugging Face checkpoint converted from the MaxText Orbax checkpoint
# We run this with context parallelism and the `ici_context_parallelism` flag as an integration test for context parallelism
python3 -m tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_hf ici_context_parallelism=4 model_name=llama2-7b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --hf_model_path=/tmp/hf_llama2 --max_kl_div=1e-4
